"""
OSQP Solver pure python implementation: low level module
"""
from __future__ import print_function
from builtins import range
from builtins import object
import numpy as np
import scipy.sparse as spspa
import scipy.sparse.linalg as spla
import numpy.linalg as la
import time

# Solver status codes (stored in info.status_val by OSQP.update_status)
OSQP_DUAL_INFEASIBLE_INACCURATE = 4
OSQP_PRIMAL_INFEASIBLE_INACCURATE = 3
OSQP_SOLVED_INACCURATE = 2
OSQP_SOLVED = 1
OSQP_MAX_ITER_REACHED = -2
OSQP_PRIMAL_INFEASIBLE = -3
OSQP_DUAL_INFEASIBLE = -4
OSQP_NON_CVX = -7
OSQP_UNSOLVED = -10

# Bounds and ratios for the ADMM step size rho
RHO_MIN = 1e-06  # lower clamp for rho; also the rho given to loose constraints
RHO_MAX = 1e06  # upper clamp for rho
RHO_EQ_OVER_RHO_INEQ = 1e03  # rho of equality rows = this factor * settings.rho
RHO_TOL = 1e-04  # rows with u - l below this are treated as equalities


# Printing interval
PRINT_INTERVAL = 200

# OSQP infinity: bounds beyond OSQP_INFTY * MIN_SCALING are treated as infinite
OSQP_INFTY = 1e30

# OSQP NaN
OSQP_NAN = np.nan

# Linear system solver options
QDLDL_SOLVER = 0

# Limits applied to the norms used for scaling (see OSQP._limit_scaling)
MIN_SCALING = 1e-04
MAX_SCALING = 1e04


class workspace(object):
    """
    OSQP solver workspace

    The class body declares no attributes; the fields listed below are
    attached to an instance by the solver code.

    Attributes
    ----------
    data                   - scaled QP problem
    info                   - solver information
    linsys_solver          - structure for linear system solution
    scaling                - scaling matrices
    settings               - settings structure
    solution               - solution structure
    pol                    - polish data (x, y, z and their obj_val,
                             pri_res, dua_res as written by update_info)


    Additional workspace variables
    ------------------------------
    first_run              - flag to indicate if it is the first run
    clear_update_time      - flag to indicate if update_time should be cleared
    timer                  - saved time instant for timing purposes
    x                      - primal iterate
    x_prev                 - previous primal iterate
    xz_tilde               - x_tilde and z_tilde iterates stacked together
    y                      - dual iterate
    z                      - z iterate
    z_prev                 - previous z iterate

    Vectorized rho parameter
    ------------------------
    rho_vec                - vector of rho values for each constraint
    rho_inv_vec            - vector of reciprocal rho values
    constr_type            - type of constraints: loose (-1), eq (1), ineq (0)

    Primal infeasibility related workspace variables
    ------------------------------------------------
    delta_y                - difference of consecutive y
    Atdelta_y              - A' * delta_y

    Dual infeasibility related workspace variables
    ----------------------------------------------
    delta_x                - difference of consecutive x
    Pdelta_x               - P * delta_x
    Adelta_x               - A * delta_x

    """


class problem(object):
    """
    Container for a QP of the form
        minimize    1/2 x' P x + q' x
        subject to  l <= A x <= u

    Attributes
    ----------
    n, m     - number of variables and number of constraints
    P        - (n x n) cost matrix, built as a scipy CSC matrix
    q        - linear cost vector, stored as passed in
    A        - (m x n) constraint matrix, built as a scipy CSC matrix
    l, u     - constraint bounds; if None is passed they are replaced by
               vectors of length m filled with -inf and +inf respectively
    """

    def __init__(self, dims, Pdata, Pindices, Pindptr, q, Adata, Aindices, Aindptr, l, u):
        # dims is the pair (number of variables, number of constraints)
        n_vars, n_constr = dims
        self.n = n_vars
        self.m = n_constr

        # Matrices arrive as raw CSC triplets (data, indices, indptr)
        P_csc = (Pdata, Pindices, Pindptr)
        A_csc = (Adata, Aindices, Aindptr)
        self.P = spspa.csc_matrix(P_csc, shape=(n_vars, n_vars))
        self.q = q
        self.A = spspa.csc_matrix(A_csc, shape=(n_constr, n_vars))

        # Missing bounds mean "unbounded" on that side
        if l is None:
            l = np.full(n_constr, -np.inf)
        if u is None:
            u = np.full(n_constr, np.inf)
        self.l = l
        self.u = u


class settings(object):
    """
    OSQP solver settings

    Every setting is read from the keyword arguments of the constructor;
    the value in brackets is the default used when the keyword is absent.
    Unknown keywords are ignored.

    Attributes
    ----------
    -> These cannot be changed without running setup
    sigma    [1e-06]               - Regularization added to P in the KKT matrix
    scaling  [10]                  - Scaling/Equilibration iterations (0 disabled)

    -> These can be changed without running setup
    rho  [0.1]                     - Step in ADMM procedure
    max_iter [4000]                - Maximum number of iterations
    eps_abs  [1e-3]                - Absolute tolerance
    eps_rel  [1e-3]                - Relative tolerance
    eps_prim_inf  [1e-4]           - Primal infeasibility tolerance
    eps_dual_inf  [1e-4]           - Dual infeasibility tolerance
    alpha [1.6]                    - Relaxation parameter
    delta [1e-6]                   - Regularization parameter for polish
    verbose  [True]                - Verbosity
    scaled_termination [False]     - Evaluate scaled termination criteria
    check_termination  [True]      - Interval for termination checking
    warm_start [True]              - Reuse solution from previous solve
    polish  [False]                - Solution polish
    polish_refine_iter  [3]        - Iterative refinement iterations

    -> Further settings read by the constructor
    linsys_solver [QDLDL_SOLVER]   - Linear system solver identifier
    adaptive_rho [True]            - Enable adaptation of rho
    adaptive_rho_interval [200]    - Adaptive rho interval
    adaptive_rho_tolerance [5]     - Factor by which the rho estimate must
                                     differ from rho to trigger an update
    adaptive_rho_fraction [0.7]    - Adaptive rho fraction
    """

    def __init__(self, **kwargs):
        # (name, default) pairs, in the order the attributes are created
        defaults = (
            ('rho', 0.1),
            ('sigma', 1e-06),
            ('scaling', 10),
            ('max_iter', 4000),
            ('eps_abs', 1e-3),
            ('eps_rel', 1e-3),
            ('eps_prim_inf', 1e-4),
            ('eps_dual_inf', 1e-4),
            ('alpha', 1.6),
            ('linsys_solver', QDLDL_SOLVER),
            ('delta', 1e-6),
            ('verbose', True),
            ('scaled_termination', False),
            ('check_termination', True),
            ('warm_start', True),
            ('polish', False),
            ('polish_refine_iter', 3),
            ('adaptive_rho', True),
            ('adaptive_rho_interval', 200),
            ('adaptive_rho_tolerance', 5),
            ('adaptive_rho_fraction', 0.7),
        )
        for name, fallback in defaults:
            setattr(self, name, kwargs.pop(name, fallback))


class scaling(object):
    """
    Holder for the diagonal scaling matrices and the cost scaling

    All fields start as None and are filled in by the scaling routine.

    Attributes
    ----------
    D        - matrix in R^{n \\times n}
    E        - matrix in R^{m \\times m}
    Dinv     - inverse of D
    Einv     - inverse of E
    c        - cost scaling
    cinv     - inverse of cost scaling
    """

    def __init__(self):
        # Chained assignment creates the attributes left to right
        self.D = self.E = self.Dinv = self.Einv = self.c = self.cinv = None


class linesearch(object):
    """
    Holder for the vectors produced by a line search between the ADMM
    solution and the polished solution

    All fields start as None.

    Attributes
    ----------
    X     - matrix in R^{N \\times n}
    Z     - matrix in R^{N \\times m}
    Y     - matrix in R^{N \\times m}
    t     - vector in R^N
    """

    def __init__(self):
        # Chained assignment creates the attributes left to right
        self.X = self.Z = self.Y = self.t = None


class solution(object):
    """
    Solver solution vectors

    Attributes
    ----------
    x     - primal solution (None until set)
    y     - dual solution (None until set)
    """

    def __init__(self):
        self.x, self.y = None, None


class info(object):
    """
    Solver information

    Only iter, status_val, status, status_polish, update_time,
    polish_time and rho_updates are created by the constructor; the
    remaining fields are attached later by the solver.

    Attributes
    ----------
    iter            - number of iterations taken
    status          - status string, e.g. 'Solved'
    status_val      - status code, one of the OSQP_* constants of this module
    status_polish   - polish status: successful (1), not (0)
    obj_val         - primal objective
    pri_res         - norm of primal residual
    dua_res         - norm of dual residual
    setup_time      - time taken for setup phase (seconds)
    solve_time      - time taken for solve phase (seconds)
    update_time     - time taken for update phase (seconds)
    polish_time     - time taken for polish phase (seconds)
    run_time        - total time  (seconds)
    rho_updates     - number of rho updates (integer counter)
    rho_estimate    - optimal rho estimate
    """

    def __init__(self):
        self.iter = 0
        self.status_val = OSQP_UNSOLVED
        self.status = 'Unsolved'
        self.status_polish = 0
        self.update_time = 0.0
        self.polish_time = 0.0
        # A count, so start from an int: OSQP.reset_info also resets it to
        # the integer 0 and OSQP.adapt_rho increments it by 1.
        self.rho_updates = 0


class polish(object):
    """
    Polishing structure containing active constraints at the solution

    All fields start as None.

    Attributes
    ----------
    ind_low         - indices of lower-active constraints
    ind_upp         - indices of upper-active constraints
    n_low           - number of lower-active constraints
    n_upp           - number of upper-active constraints
    Ared            - Part of A containing only active rows
    x               - polished x
    z               - polished z
    y               - polished y
    """

    def __init__(self):
        # Create every field, in this order, with no value yet
        field_names = ('ind_low', 'ind_upp', 'n_low', 'n_upp', 'Ared', 'x', 'z', 'y')
        for field in field_names:
            setattr(self, field, None)


class linsys_solver(object):
    """
    Linear systems solver for the reduced KKT matrix

        [ P + sigma * I          A'          ]
        [       A        -diag(rho_inv_vec)  ]

    The matrix is assembled and factorized once in the constructor with
    a sparse LU decomposition; solve() reuses that factorization.
    """

    def __init__(self, work):
        """
        Assemble the KKT matrix from the workspace and factorize it

        Reads work.data (P, A, n), work.settings.sigma and
        work.rho_inv_vec.
        """
        data = work.data

        # Diagonal blocks of the KKT matrix
        upper_left = data.P + work.settings.sigma * spspa.eye(data.n)
        lower_right = -spspa.diags(work.rho_inv_vec)

        # Stack the two block rows
        top_rows = spspa.hstack([upper_left, data.A.T])
        bottom_rows = spspa.hstack([data.A, lower_right])
        KKT = spspa.vstack([top_rows, bottom_rows])

        # splu expects CSC storage
        self.kkt_factor = spla.splu(KKT.tocsc())

    def solve(self, rhs):
        """
        Return the solution of KKT * sol = rhs using the stored factorization
        """
        factorization = self.kkt_factor
        return factorization.solve(rhs)


class results(object):
    """
    Results structure returned to the caller

    Attributes
    ----------
    x           - primal solution (taken from solution.x)
    y           - dual solution (taken from solution.y)
    info        - info structure
    linesearch  - line search structure, stored as passed in
    """

    def __init__(self, solution, info, linesearch):
        # Only the two vectors are copied out of the solution object
        self.x, self.y = solution.x, solution.y
        self.info = info
        self.linesearch = linesearch


class OSQP(object):
    """OSQP solver lower level interface
    Attributes
    ----------
    work    - workspace
    """

    def __init__(self):
        # Version string exposed through the read-only `version` property
        self._version = '0.6.2.post0'

    @property
    def version(self):
        """Return the solver version string stored in ``_version``"""
        return self._version

    def _norm_KKT_cols(self, P, A):
        """
        Compute the norm of the KKT matrix from P and A
        """

        # First half
        norm_P_cols = spspa.linalg.norm(P, np.inf, axis=0)
        norm_A_cols = spspa.linalg.norm(A, np.inf, axis=0)
        norm_first_half = np.maximum(norm_P_cols, norm_A_cols)

        # Second half (norm cols of A')
        norm_second_half = spspa.linalg.norm(A, np.inf, axis=1)

        return np.hstack((norm_first_half, norm_second_half))

    def _limit_scaling(self, norm_vec):
        """
        Limit the norms used for scaling

        Values below MIN_SCALING are replaced by 1.0, values above
        MAX_SCALING are replaced by MAX_SCALING and every other value is
        kept. A list, tuple or numpy array gives back a new float array;
        any other input is treated as a scalar.
        """
        if not isinstance(norm_vec, (list, tuple, np.ndarray)):  # Scalar
            if norm_vec < MIN_SCALING:
                return 1.0
            return MAX_SCALING if norm_vec > MAX_SCALING else norm_vec

        # Array: same rule applied elementwise; the "too small" test wins
        raw = np.asarray(norm_vec)
        too_small = raw < MIN_SCALING
        too_large = raw > MAX_SCALING
        limited = np.where(too_small, 1.0, np.where(too_large, MAX_SCALING, raw))
        return limited.astype(float)

    def scale_data(self):
        """
        Perform symmetric diagonal scaling via equilibration

        Runs settings.scaling iterations. Each iteration
          1. scales rows/columns with the reciprocal square root of the
             limited KKT column norms (Ruiz step), and
          2. normalizes the cost by the larger of the mean column norm
             of P and the inf-norm of q (both limited).
        On exit work.data holds the scaled problem and work.scaling the
        accumulated matrices D, E, their inverses, and the cost scaling
        c with its inverse.
        """
        data = self.work.data
        n = data.n
        m = data.m

        # Working copies of the problem data
        P, q, A = data.P, data.q, data.A
        lower, upper = data.l, data.u

        # Accumulated scalings
        c = 1.0
        D = spspa.eye(n)
        # spspa.diags() will throw an error if fed with an empty array,
        # so the m == 0 case uses an explicit empty matrix throughout
        E = spspa.csc_matrix((0, 0)) if m == 0 else spspa.eye(m)

        for _ in range(self.work.settings.scaling):
            # --- Step 1: Ruiz equilibration ---
            col_norms = self._limit_scaling(self._norm_KKT_cols(P, A))
            step = np.reciprocal(np.sqrt(col_norms))

            D_step = spspa.diags(step[:n])
            E_step = spspa.csc_matrix((0, 0)) if m == 0 else spspa.diags(step[n:])

            P = D_step.dot(P.dot(D_step)).tocsc()
            A = E_step.dot(A.dot(D_step)).tocsc()
            q = D_step.dot(q)
            lower = E_step.dot(lower)
            upper = E_step.dot(upper)

            D = D_step.dot(D)
            E = E_step.dot(E)

            # --- Step 2: cost normalization ---
            mean_P_col_norm = spla.norm(P, np.inf, axis=0).mean()
            q_norm = self._limit_scaling(la.norm(q, np.inf))
            cost_norm = self._limit_scaling(np.maximum(q_norm, mean_P_col_norm))
            c_step = 1.0 / cost_norm

            P = c_step * P
            q = c_step * q
            c = c_step * c

        if self.work.settings.verbose:
            print('Final cost scaling = %.10f' % c)

        # Store the scaled problem
        self.work.data = problem(
            (n, m),
            P.data,
            P.indices,
            P.indptr,
            q,
            A.data,
            A.indices,
            A.indptr,
            lower,
            upper,
        )

        # Store the scaling matrices and their inverses
        scl = scaling()
        scl.D = D
        scl.Dinv = spspa.diags(np.reciprocal(D.diagonal()))
        scl.E = E
        scl.Einv = E if m == 0 else spspa.diags(np.reciprocal(E.diagonal()))
        scl.c = c
        scl.cinv = 1.0 / c
        self.work.scaling = scl

    def set_rho_vec(self):
        """
        Clamp settings.rho and fill constr_type, rho_vec and rho_inv_vec

        Each constraint is classified from its bounds:
          loose (-1)  both bounds beyond +-OSQP_INFTY * MIN_SCALING,
                      rho = RHO_MIN
          eq (1)      u - l < RHO_TOL,
                      rho = RHO_EQ_OVER_RHO_INEQ * settings.rho
          ineq (0)    everything else, rho = settings.rho
        """
        work = self.work
        opts = work.settings

        # Keep rho inside [RHO_MIN, RHO_MAX]
        opts.rho = np.minimum(np.maximum(opts.rho, RHO_MIN), RHO_MAX)

        lower = work.data.l
        upper = work.data.u
        infinite_bound = OSQP_INFTY * MIN_SCALING

        # Boolean masks, one per constraint type
        is_loose = np.logical_and(lower < -infinite_bound, upper > infinite_bound)
        is_eq = upper - lower < RHO_TOL
        is_ineq = ~(is_loose | is_eq)

        # Same assignment order as the classification above
        for mask, type_code, rho_value in (
            (is_loose, -1, RHO_MIN),
            (is_eq, 1, RHO_EQ_OVER_RHO_INEQ * opts.rho),
            (is_ineq, 0, opts.rho),
        ):
            work.constr_type[mask] = type_code
            work.rho_vec[mask] = rho_value

        work.rho_inv_vec = np.reciprocal(work.rho_vec)

    def update_rho_vec(self):
        """
        Update values of rho_vec and refactor if constraints change.

        The constraints are re-classified from the current bounds
        (loose: -1, equality: 1, inequality: 0), constr_type, rho_vec
        and rho_inv_vec are refreshed, and the linear system solver is
        rebuilt only when at least one constraint changed type.
        """
        work = self.work
        lower = work.data.l
        upper = work.data.u
        infinite_bound = OSQP_INFTY * MIN_SCALING

        # Find indices of loose bounds, equality constr and one-sided constr
        loose_ind = np.where(
            np.logical_and(lower < -infinite_bound, upper > infinite_bound)
        )[0]
        eq_ind = np.where(upper - lower < RHO_TOL)[0]
        ineq_ind = np.setdiff1d(np.setdiff1d(np.arange(work.data.m), loose_ind), eq_ind)

        # Remember the previous classification. Comparing the full type
        # vector (rather than per-type index arrays, whose lengths differ
        # as soon as a constraint switches type) makes the change check
        # well defined for every case.
        old_constr_type = np.copy(work.constr_type)

        # Update type of constraints
        work.constr_type[loose_ind] = -1
        work.constr_type[eq_ind] = 1
        work.constr_type[ineq_ind] = 0

        constr_type_changed = not np.array_equal(old_constr_type, work.constr_type)

        work.rho_vec[loose_ind] = RHO_MIN
        work.rho_vec[eq_ind] = RHO_EQ_OVER_RHO_INEQ * work.settings.rho
        work.rho_vec[ineq_ind] = work.settings.rho

        work.rho_inv_vec = np.reciprocal(work.rho_vec)

        # rho_inv_vec enters the KKT matrix, so a type change needs a
        # new factorization
        if constr_type_changed:
            work.linsys_solver = linsys_solver(work)

    def print_setup_header(self, data, settings):
        """
        Print the solver banner, the problem size and the settings

        The variable/constraint counts come from `data`; the nonzero
        count is taken from the workspace data (self.work.data).
        """
        rule = '--------------------------------------------------------------'

        # Banner
        print(rule)
        print('         OSQP v%s  -  Operator Splitting QP Solver' % self.version)
        print('                 Pure Python Implementation')
        print('        (c) Bartolomeo Stellato, Goran Banjac')
        print('      University of Oxford  -  Stanford University 2017')
        print(rule)

        # Problem size
        print('problem:  variables n = %d, constraints m = %d' % (data.n, data.m))
        print('          nnz(P) + nnz(A) = %i' % (self.work.data.P.nnz + self.work.data.A.nnz))

        # Settings
        print('settings: ', end='')
        if settings.linsys_solver == QDLDL_SOLVER:
            print('linear system solver = qdldl\n          ', end='')
        print('eps_abs = %.2e, eps_rel = %.2e,' % (settings.eps_abs, settings.eps_rel))
        print('          eps_prim_inf = %.2e, eps_dual_inf = %.2e,' % (settings.eps_prim_inf, settings.eps_dual_inf))
        print('          rho = %.2e ' % settings.rho, end='')
        print('(adaptive)' if settings.adaptive_rho else '')
        print(
            '          sigma = %.2e, alpha = %.2f, ' % (settings.sigma, settings.alpha),
            end='',
        )
        print('max_iter = %d' % settings.max_iter)

        # On/off switches, two per line
        print('          scaling: on, ' if settings.scaling else '          scaling: off, ', end='')
        print('scaled_termination: on' if settings.scaled_termination else 'scaled_termination: off')
        print('          warm_start: on, ' if settings.warm_start else '          warm_start: off, ', end='')
        print('polish: on' if settings.polish else 'polish: off')
        print('')

    def print_header(self):
        """
        Print header before the iterations
        """
        print('iter   objective    pri res    dua res    rho       time')

    def update_status(self, status):
        """
        Store a status code in info.status_val and, when the code is one
        of the known OSQP_* constants, the matching text in info.status

        An unrecognised code leaves info.status untouched.
        """
        self.work.info.status_val = status

        descriptions = {
            OSQP_SOLVED: 'solved',
            OSQP_SOLVED_INACCURATE: 'solved inaccurate',
            OSQP_PRIMAL_INFEASIBLE: 'primal infeasible',
            OSQP_PRIMAL_INFEASIBLE_INACCURATE: 'primal infeasible inaccurate',
            OSQP_UNSOLVED: 'unsolved',
            OSQP_DUAL_INFEASIBLE: 'dual infeasible',
            OSQP_DUAL_INFEASIBLE_INACCURATE: 'dual infeasible inaccurate',
            OSQP_MAX_ITER_REACHED: 'maximum iterations reached',
            OSQP_NON_CVX: 'problem non convex',
        }
        if status in descriptions:
            self.work.info.status = descriptions[status]

    def cold_start(self):
        """
        Cold start optimization variables to zero
        """
        self.work.x = np.zeros(self.work.data.n)
        self.work.z = np.zeros(self.work.data.m)
        self.work.y = np.zeros(self.work.data.m)

    def update_xz_tilde(self):
        """
        First ADMM step: update xz_tilde
        """
        # Compute rhs and store it in xz_tilde
        self.work.xz_tilde[: self.work.data.n] = self.work.settings.sigma * self.work.x_prev - self.work.data.q
        self.work.xz_tilde[self.work.data.n :] = self.work.z_prev - self.work.rho_inv_vec * self.work.y

        # Solve linear system
        self.work.xz_tilde = self.work.linsys_solver.solve(self.work.xz_tilde)

        # Update z_tilde
        self.work.xz_tilde[self.work.data.n :] = self.work.z_prev + self.work.rho_inv_vec * (
            self.work.xz_tilde[self.work.data.n :] - self.work.y
        )

    def update_x(self):
        """
        Update x variable in second ADMM step
        """
        self.work.x = (
            self.work.settings.alpha * self.work.xz_tilde[: self.work.data.n]
            + (1.0 - self.work.settings.alpha) * self.work.x_prev
        )
        self.work.delta_x = self.work.x - self.work.x_prev

    def project(self, z):
        """
        Project z variable in set C (for now C = [l, u])
        """
        return np.minimum(np.maximum(z, self.work.data.l), self.work.data.u)

    def project_normalcone(self, z, y):
        tmp = z + y
        z = np.minimum(np.maximum(tmp, self.work.data.l), self.work.data.u)
        y = tmp - z
        return z, y

    def update_z(self):
        """
        Update z variable in second ADMM step
        """
        self.work.z = (
            self.work.settings.alpha * self.work.xz_tilde[self.work.data.n :]
            + (1.0 - self.work.settings.alpha) * self.work.z_prev
            + self.work.rho_inv_vec * self.work.y
        )

        self.work.z = self.project(self.work.z)

    def update_y(self):
        """
        Third ADMM step: update dual variable y
        """
        self.work.delta_y = self.work.rho_vec * (
            self.work.settings.alpha * self.work.xz_tilde[self.work.data.n :]
            + (1.0 - self.work.settings.alpha) * self.work.z_prev
            - self.work.z
        )
        self.work.y += self.work.delta_y

    def compute_obj_val(self, x):
        # Compute quadratic objective value for the given x
        obj_val = 0.5 * np.dot(x, self.work.data.P.dot(x)) + np.dot(self.work.data.q, x)

        if self.work.settings.scaling:
            obj_val *= self.work.scaling.cinv

        return obj_val

    def compute_pri_res(self, x, z):
        """
        Compute primal residual ||Ax - z||
        """

        # Primal residual
        Ax = self.work.data.A.dot(x)
        pri_res = Ax - z

        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            pri_res = self.work.scaling.Einv.dot(pri_res)

        return la.norm(pri_res, np.inf)

    def compute_pri_tol(self, eps_abs, eps_rel):
        """
        Compute primal tolerance using problem data
        """
        A = self.work.data.A
        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            Einv = self.work.scaling.Einv
            max_rel_eps = np.max(
                [
                    la.norm(Einv.dot(A.dot(self.work.x)), np.inf),
                    la.norm(Einv.dot(self.work.z), np.inf),
                ]
            )
        else:
            max_rel_eps = np.max(
                [
                    la.norm(A.dot(self.work.x), np.inf),
                    la.norm(self.work.z, np.inf),
                ]
            )

        eps_pri = eps_abs + eps_rel * max_rel_eps

        return eps_pri

    def compute_dua_res(self, x, y):
        """
        Compute dual residual ||Px + q + A'y||
        """

        dua_res = self.work.data.P.dot(x) + self.work.data.q + self.work.data.A.T.dot(y)

        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            # Use unscaled residual
            dua_res = self.work.scaling.cinv * self.work.scaling.Dinv.dot(dua_res)

        return la.norm(dua_res, np.inf)

    def compute_dua_tol(self, eps_abs, eps_rel):
        """
        Compute dual tolerance
        """
        P = self.work.data.P
        q = self.work.data.q
        A = self.work.data.A
        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            cinv = self.work.scaling.cinv
            Dinv = self.work.scaling.Dinv
            max_rel_eps = cinv * np.max(
                [
                    la.norm(Dinv.dot(A.T.dot(self.work.y)), np.inf),
                    la.norm(Dinv.dot(P.dot(self.work.x)), np.inf),
                    la.norm(Dinv.dot(q), np.inf),
                ]
            )
        else:
            max_rel_eps = np.max(
                [
                    la.norm(A.T.dot(self.work.y), np.inf),
                    la.norm(P.dot(self.work.x), np.inf),
                    la.norm(q, np.inf),
                ]
            )

        eps_dua = eps_abs + eps_rel * max_rel_eps

        return eps_dua

    def is_primal_infeasible(self, eps_prim_inf):
        """
        Check primal infeasibility
                ||A'*v||_2 = 0
        with v = delta_y/||delta_y||_2 given that following condition holds
            u'*(v)_{+} + l'*(v)_{-} < 0
        """

        # Rescale delta_y
        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            norm_delta_y = la.norm(self.work.scaling.E.dot(self.work.delta_y), np.inf)
        else:
            norm_delta_y = la.norm(self.work.delta_y, np.inf)

        if norm_delta_y > eps_prim_inf:
            lhs = self.work.data.u.dot(np.maximum(self.work.delta_y, 0)) + self.work.data.l.dot(
                np.minimum(self.work.delta_y, 0)
            )
            if lhs < -eps_prim_inf * norm_delta_y:
                self.work.Atdelta_y = self.work.data.A.T.dot(self.work.delta_y)
                if self.work.settings.scaling and not self.work.settings.scaled_termination:
                    self.work.Atdelta_y = self.work.scaling.Dinv.dot(self.work.Atdelta_y)
                return la.norm(self.work.Atdelta_y, np.inf) < eps_prim_inf * norm_delta_y

        return False

    def is_dual_infeasible(self, eps_dual_inf):
        """
        Check dual infeasibility
            ||P*v||_inf = 0
        with v = delta_x / ||delta_x||_inf given that the following
        conditions hold
            q'* v < 0 and
                        | 0     if l_i, u_i in R
            (A * v)_i = { >= 0  if u_i = +inf
                        | <= 0  if l_i = -inf
        """
        # Rescale delta_x
        if self.work.settings.scaling and not self.work.settings.scaled_termination:
            norm_delta_x = la.norm(self.work.scaling.D.dot(self.work.delta_x), np.inf)
            scale_cost = self.work.scaling.c
        else:
            norm_delta_x = la.norm(self.work.delta_x, np.inf)
            scale_cost = 1.0

        # Prevent 0 division
        if norm_delta_x > eps_dual_inf:
            # First check q'* delta_x < 0
            if self.work.data.q.dot(self.work.delta_x) < -scale_cost * eps_dual_inf * norm_delta_x:
                # Compute P * delta_x
                self.work.Pdelta_x = self.work.data.P.dot(self.work.delta_x)

                # Scale if necessary
                if self.work.settings.scaling and not self.work.settings.scaled_termination:
                    self.work.Pdelta_x = self.work.scaling.Dinv.dot(self.work.Pdelta_x)

                # Check if ||P * delta_x|| = 0
                if la.norm(self.work.Pdelta_x, np.inf) < scale_cost * eps_dual_inf * norm_delta_x:
                    # Compute A * delta_x
                    self.work.Adelta_x = self.work.data.A.dot(self.work.delta_x)

                    # Scale if necessary
                    if self.work.settings.scaling and not self.work.settings.scaled_termination:
                        self.work.Adelta_x = self.work.scaling.Einv.dot(self.work.Adelta_x)

                    for i in range(self.work.data.m):
                        # De Morgan's Law applied to negate
                        # conditions on A * delta_x
                        if (
                            (self.work.data.u[i] < OSQP_INFTY * MIN_SCALING)
                            and (self.work.Adelta_x[i] > eps_dual_inf * norm_delta_x)
                        ) or (
                            (self.work.data.l[i] > -OSQP_INFTY * MIN_SCALING)
                            and (self.work.Adelta_x[i] < -eps_dual_inf * norm_delta_x)
                        ):
                            # At least one condition not satisfied
                            return False

                    # All conditions passed -> dual infeasible
                    return True

        # No all checks managed to pass. Problem not dual infeasible
        return False

    def compute_rho_estimate(self):
        """
        Estimate a new rho from the normalized primal and dual residuals

        Returns settings.rho * sqrt(pri_res / dua_res), clamped to
        [RHO_MIN, RHO_MAX], where both residuals are inf-norms divided
        by the largest inf-norm among their terms. The 1e-10 offsets
        guard the divisions against zero denominators.
        """
        work = self.work
        x, y, z = work.x, work.y, work.z
        P, q, A = work.data.P, work.data.q, work.data.A

        # Products shared by the residuals and their normalizations
        Ax = A.dot(x)
        Px = P.dot(x)
        Aty = A.T.dot(y)

        # Normalized primal residual
        pri_scale = np.max([la.norm(Ax, np.inf), la.norm(z, np.inf)]) + 1e-10
        pri_res = la.norm(Ax - z, np.inf) / pri_scale

        # Normalized dual residual
        dua_scale = (
            np.max(
                [
                    la.norm(Aty, np.inf),
                    la.norm(Px, np.inf),
                    la.norm(q, np.inf),
                ]
            )
            + 1e-10
        )
        dua_res = la.norm(Px + q + Aty, np.inf) / dua_scale

        # Compute new rho
        new_rho = work.settings.rho * np.sqrt(pri_res / (dua_res + 1e-10))
        return min(max(new_rho, RHO_MIN), RHO_MAX)

    def adapt_rho(self):
        """
        Adapt rho value based on current primal and dual residuals
        """
        # Compute new rho
        rho_new = self.compute_rho_estimate()

        # Update rho estimate
        self.work.info.rho_estimate = rho_new

        # Settings
        adaptive_rho_tolerance = self.work.settings.adaptive_rho_tolerance

        if (
            rho_new > adaptive_rho_tolerance * self.work.settings.rho
            or rho_new < 1.0 / adaptive_rho_tolerance * self.work.settings.rho
        ):
            # Update rho
            self.update_rho(rho_new)
            # Update rho updates count
            self.work.info.rho_updates += 1

    def reset_info(self, info):
        """
        Reset information after problem updates

        Zeroes the solve and polish timers of `info`, marks the
        workspace status as unsolved and clears the rho update count of
        `info`.
        """
        # Timers restart from zero
        for timer_name in ('solve_time', 'polish_time'):
            setattr(info, timer_name, 0.0)

        # Note: the status is written to self.work.info by update_status
        self.update_status(OSQP_UNSOLVED)

        info.rho_updates = 0

    def update_info(self, iter, polish):
        """
        Refresh objective value, residuals and timing

        With polish == 1 the values are computed from the polished iterates
        and written to the polish structure, and info.polish_time is set to
        the time elapsed since the workspace timer. Otherwise they are
        computed from the ADMM iterates and written to info together with
        the iteration number `iter` and info.solve_time.
        """
        work = self.work

        if polish != 1:
            # Regular ADMM iteration
            info = work.info
            info.iter = iter
            info.obj_val = self.compute_obj_val(work.x)
            info.pri_res = self.compute_pri_res(work.x, work.z)
            info.dua_res = self.compute_dua_res(work.x, work.y)
            info.solve_time = time.time() - work.timer
            return

        # Polished solution
        pol = work.pol
        pol.obj_val = self.compute_obj_val(pol.x)
        pol.pri_res = self.compute_pri_res(pol.x, pol.z)
        pol.dua_res = self.compute_dua_res(pol.x, pol.y)
        work.info.polish_time = time.time() - work.timer

    def print_summary(self):
        """
        Print one line of the iteration log

        The line shows iteration number, objective value, primal and dual
        residuals, the current rho and the elapsed time. The elapsed time is
        the solve time plus the setup time on the first run, or plus the
        update time on later runs.
        """
        info = self.work.info

        # Time spent before the iterations started
        overhead = info.setup_time if self.work.first_run else info.update_time
        elapsed = overhead + info.solve_time

        row = (
            info.iter,
            info.obj_val,
            info.pri_res,
            info.dua_res,
            self.work.settings.rho,
            elapsed,
        )
        print('%4i  %11.4e   %8.2e   %8.2e   %8.2e  %8.2es' % row)

    def print_polish(self):
        """
        Print the log line describing the polished solution

        Shows objective value and residuals stored in info, followed by the
        total elapsed time: setup time (first run) or update time (later
        runs) plus solve time plus polish time. The rho column is left blank.
        """
        info = self.work.info

        # Time spent before the iterations started
        overhead = info.setup_time if self.work.first_run else info.update_time
        elapsed = overhead + info.solve_time + info.polish_time

        row = (info.obj_val, info.pri_res, info.dua_res, elapsed)
        print('plsh  %11.4e   %8.2e   %8.2e   --------  %8.2es' % row)

    def check_termination(self, approximate=False):
        """
        Test the current residuals against the termination criteria

        Args
        ----
            approximate: when true, all four tolerances (eps_abs, eps_rel,
                         eps_prim_inf, eps_dual_inf) are multiplied by 10
                         and the "inaccurate" status codes are used

        Returns
        -------
            1 when a terminal status has been written to info.status_val
            (solved, primal/dual infeasible or non convex). When no
            criterion is met the method falls through and returns None.
        """
        work = self.work
        info = work.info
        settings = work.settings

        # Tolerances, loosened for the approximate test
        eps_abs = settings.eps_abs
        eps_rel = settings.eps_rel
        eps_prim_inf = settings.eps_prim_inf
        eps_dual_inf = settings.eps_dual_inf
        if approximate:
            eps_abs, eps_rel = eps_abs * 10, eps_rel * 10
            eps_prim_inf, eps_dual_inf = eps_prim_inf * 10, eps_dual_inf * 10

        # Residuals beyond OSQP_INFTY: the problem is probably non convex
        if info.pri_res > OSQP_INFTY or info.dua_res > OSQP_INFTY:
            info.status_val = OSQP_NON_CVX
            info.obj_val = OSQP_NAN
            return 1

        # Primal side. Without constraints the iterate is always primal
        # feasible; infeasibility is only tested when the residual test fails.
        pri_ok = False
        prim_infeasible = 0
        if work.data.m == 0:
            pri_ok = True
        elif info.pri_res < self.compute_pri_tol(eps_abs, eps_rel):
            pri_ok = True
        else:
            prim_infeasible = self.is_primal_infeasible(eps_prim_inf)

        # Dual side, same pattern
        dua_ok = False
        dual_infeasible = 0
        if info.dua_res < self.compute_dua_tol(eps_abs, eps_rel):
            dua_ok = True
        else:
            dual_infeasible = self.is_dual_infeasible(eps_dual_inf)

        # Both residuals small enough: solved
        if pri_ok and dua_ok:
            info.status_val = OSQP_SOLVED_INACCURATE if approximate else OSQP_SOLVED
            return 1

        # Primal infeasible
        if prim_infeasible:
            info.status_val = OSQP_PRIMAL_INFEASIBLE_INACCURATE if approximate else OSQP_PRIMAL_INFEASIBLE
            info.obj_val = OSQP_INFTY
            # Store the certificate in the original (unscaled) space
            if settings.scaling and not settings.scaled_termination:
                work.delta_y = work.scaling.E.dot(work.delta_y)
            return 1

        # Dual infeasible
        if dual_infeasible:
            info.status_val = OSQP_DUAL_INFEASIBLE_INACCURATE if approximate else OSQP_DUAL_INFEASIBLE
            # Store the certificate in the original (unscaled) space
            if settings.scaling and not settings.scaled_termination:
                work.delta_x = work.scaling.D.dot(work.delta_x)
            info.obj_val = -OSQP_INFTY
            return 1

    def print_footer(self):
        """
        Print footer at the end of the optimization

        Reports the status string, the polish outcome (only when polish is
        enabled and the status is OSQP_SOLVED), the number of iterations,
        the objective value (only for solved / solved inaccurate), the total
        run time and the rho estimate.
        """
        info = self.work.info
        solved = info.status_val == OSQP_SOLVED

        print('')  # Add space after iterations
        print('status:               %s' % info.status)
        if self.work.settings.polish and solved:
            if info.status_polish == 1:
                print('solution polish:      successful')
            elif info.status_polish == -1:
                print('solution polish:      unsuccessful')
        print('number of iterations: %d' % info.iter)
        if solved or info.status_val == OSQP_SOLVED_INACCURATE:
            print('optimal objective:    %.4f' % info.obj_val)
        # The run time is meaningful for every status, not only for solved ones
        print('run time:             %.2es' % info.run_time)
        # rho is not a time, so it carries no seconds suffix
        print('optimal rho estimate: %.2e' % info.rho_estimate)

        print('')  # Print last space

    def store_solution(self):
        """
        Store current primal and dual solution in solution structure

        Unless the status is OSQP_PRIMAL_INFEASIBLE or OSQP_DUAL_INFEASIBLE,
        the current iterates x and y are stored, unscaled when scaling is
        enabled. For those two statuses the solution vectors are filled with
        None instead.
        """
        work = self.work
        status = work.info.status_val

        # Compare by value: identity (`is not`) only works by accident for
        # ints and fails for equal values held in a different object
        # (e.g. numpy integers).
        if status != OSQP_PRIMAL_INFEASIBLE and status != OSQP_DUAL_INFEASIBLE:
            work.solution.x = work.x
            work.solution.y = work.y

            # Unscale solution
            if work.settings.scaling:
                work.solution.x = work.scaling.D.dot(work.solution.x)
                work.solution.y = work.scaling.cinv * work.scaling.E.dot(work.solution.y)
        else:
            # No meaningful solution for an infeasible problem
            work.solution.x = np.array([None] * work.data.n)
            work.solution.y = np.array([None] * work.data.m)

    #
    #   Main Solver API
    #

    def setup(self, dims, Pdata, Pindices, Pindptr, q, Adata, Aindices, Aindptr, l, u, **stgs):
        """
        Build the solver workspace for a QP problem of the form
            minimize	1/2 x' P x + q' x
            subject to	l <= A x <= u

        dims is the pair (n, m); P and A are passed as data / indices /
        indptr triplets. Any keyword argument is forwarded to the settings
        structure. The elapsed time is stored in info.setup_time and the
        setup header is printed when the verbose setting is on.
        """
        n, m = dims

        # New workspace; the timer starts immediately so that setup_time
        # covers everything done below
        work = self.work = workspace()
        work.timer = time.time()

        # Problem data as supplied by the caller (not scaled yet)
        work.data = problem((n, m), Pdata, Pindices, Pindptr, q, Adata, Aindices, Aindptr, l, u)

        # Per-constraint quantities: rho, its reciprocal and the constraint type
        work.rho_vec = np.zeros(m)
        work.rho_inv_vec = np.zeros(m)
        work.constr_type = np.zeros(m)

        # Iterates, all starting from zero
        work.x = np.zeros(n)
        work.x_prev = np.zeros(n)
        work.xz_tilde = np.zeros(n + m)
        work.z = np.zeros(m)
        work.z_prev = np.zeros(m)
        work.y = np.zeros(m)
        work.delta_y = np.zeros(m)  # Delta_y for primal infeasibility

        # first_run marks the first solve; clear_update_time tells the
        # update methods whether update_time has to be zeroed
        work.first_run = 1
        work.clear_update_time = 0

        # Solver settings from the keyword arguments
        work.settings = settings(**stgs)

        # Problem scaling, only if requested
        if work.settings.scaling:
            self.scale_data()

        # Constraint types and the corresponding rho values
        self.set_rho_vec()

        # KKT factorization
        work.linsys_solver = linsys_solver(work)

        # Output structures: solution, info and polish
        work.solution = solution()
        work.info = info()
        work.pol = polish()

        # Stop the timer
        work.info.setup_time = time.time() - work.timer

        # Setup header
        if work.settings.verbose:
            self.print_setup_header(work.data, work.settings)

    def solve(self):
        """
        Solve QP problem using OSQP

        Runs the ADMM iterations until an exact termination test succeeds or
        settings.max_iter iterations have been performed. If no exact test
        terminated the run, the approximate test (looser tolerances) is
        tried once; if that fails too the status becomes
        OSQP_MAX_ITER_REACHED. Polish is attempted only when it is enabled
        and the status is OSQP_SOLVED.

        Returns
        -------
        results object built from the stored solution, the solver info and
        the value returned by polish (None when polish is not run).
        """
        # Start timer
        self.work.timer = time.time()

        # Clear update_time
        if self.work.clear_update_time == 1:
            self.work.info.update_time = 0.0

        # Print header
        if self.work.settings.verbose:
            self.print_header()

        # Cold start if not warm start
        if not self.work.settings.warm_start:
            self.cold_start()

        # True once an exact termination test has fixed the solver status
        terminated = False
        # True when the most recently computed info has already been printed
        summary_printed = False

        # ADMM algorithm
        for k in range(1, self.work.settings.max_iter + 1):
            # Update x_prev, z_prev
            self.work.x_prev = np.copy(self.work.x)
            self.work.z_prev = np.copy(self.work.z)

            # Admm steps
            # First step: update \tilde{x} and \tilde{z}
            self.update_xz_tilde()

            # Second step: update x and z
            self.update_x()

            self.update_z()

            # Third step: update y
            self.update_y()

            if self.work.settings.check_termination:
                # Update info
                self.update_info(k, 0)
                summary_printed = False

                # Print summary
                if self.work.settings.verbose and (k % PRINT_INTERVAL == 0 or k == 1):
                    self.print_summary()
                    summary_printed = True

                # Break if converged
                if self.check_termination():
                    terminated = True
                    break

            # If not terminated, update rho in case
            if (
                self.work.settings.adaptive_rho_interval
                and (k % self.work.settings.adaptive_rho_interval == 0)
                and self.work.settings.adaptive_rho
            ):
                self.adapt_rho()

        if not self.work.settings.check_termination:
            # Termination was never tested inside the loop: do it once now
            self.update_info(self.work.settings.max_iter, 0)

            # Print summary
            if self.work.settings.verbose:
                self.print_summary()
                summary_printed = True

            terminated = bool(self.check_termination())

        # Print summary for last iteration, unless it is already on screen
        if self.work.settings.verbose and not summary_printed:
            self.print_summary()

        # Fall back to the approximate test only when no exact test
        # succeeded. Running it after a successful exact test would replace
        # OSQP_SOLVED by OSQP_SOLVED_INACCURATE and could rescale an
        # infeasibility certificate a second time.
        if not terminated:
            if not self.check_termination(approximate=True):
                self.work.info.status_val = OSQP_MAX_ITER_REACHED

        # Update status string
        self.update_status(self.work.info.status_val)

        # Update solve time
        self.work.info.solve_time = time.time() - self.work.timer

        # Update rho estimate
        self.work.info.rho_estimate = self.compute_rho_estimate()

        # Solution polish
        if self.work.settings.polish and self.work.info.status_val == OSQP_SOLVED:
            ls = self.polish()
        else:
            ls = None

        # Update total times
        if self.work.first_run:
            self.work.info.run_time = self.work.info.setup_time + self.work.info.solve_time + self.work.info.polish_time
        else:
            self.work.info.run_time = (
                self.work.info.update_time + self.work.info.solve_time + self.work.info.polish_time
            )

        # Print footer
        if self.work.settings.verbose:
            self.print_footer()

        # Store solution
        self.store_solution()

        # Eliminate first run flag
        if self.work.first_run:
            self.work.first_run = 0

        # Indicate that the update_time should be set to zero
        self.work.clear_update_time = 1

        # Store results structure
        return results(self.work.solution, self.work.info, ls)

    #
    #   Auxiliary API Functions
    #

    def update_lin_cost(self, q_new):
        """
        Replace the linear cost vector q; no new factorization is performed

        A copy of q_new is stored (scaled by c * D when scaling is enabled),
        the solver info is reset and the time spent here is added to
        info.update_time.
        """
        work = self.work

        # First update after a solve: restart the update time accumulator
        if work.clear_update_time == 1:
            work.clear_update_time = 0
            work.info.update_time = 0.0

        work.timer = time.time()

        # Keep a private copy of the new cost
        work.data.q = np.copy(q_new)

        # Bring it into the scaled problem space
        if work.settings.scaling:
            work.data.q = work.scaling.c * work.scaling.D.dot(work.data.q)

        # Previous solve information is no longer valid
        self.reset_info(work.info)

        work.info.update_time += time.time() - work.timer

    def update_bounds(self, l_new, u_new):
        """
        Replace both constraint bounds; no new factorization is requested here

        Raises ValueError when some entry of l_new exceeds the matching entry
        of u_new. Otherwise copies of the bounds are stored (scaled by E when
        scaling is enabled), the solver info is reset, update_rho_vec is
        called and the time spent is added to info.update_time.
        """
        work = self.work

        # First update after a solve: restart the update time accumulator
        if work.clear_update_time == 1:
            work.clear_update_time = 0
            work.info.update_time = 0.0

        work.timer = time.time()

        # Reject inconsistent bounds before touching the problem data
        bounds_consistent = np.greater_equal(u_new, l_new).all()
        if not bounds_consistent:
            raise ValueError('Lower bound must be lower than or equal to upper bound!')

        # Private copies of the new bounds
        work.data.l = np.copy(l_new)
        work.data.u = np.copy(u_new)

        # Bring them into the scaled problem space
        if work.settings.scaling:
            E = work.scaling.E
            work.data.l = E.dot(work.data.l)
            work.data.u = E.dot(work.data.u)

        # Previous solve information is no longer valid
        self.reset_info(work.info)

        # Constraint types may have changed with the new bounds
        self.update_rho_vec()

        work.info.update_time += time.time() - work.timer

    def update_lower_bound(self, l_new):
        """
        Update lower bound; no new factorization is requested here

        A copy of l_new is scaled (when scaling is enabled) and compared with
        the stored upper bound. Only a valid bound is written to the problem
        data, so a rejected update leaves the workspace unchanged.

        Raises
        ------
        ValueError if some entry of the new lower bound exceeds the stored
        upper bound.
        """
        if self.work.clear_update_time == 1:
            # Clear update_time
            self.work.clear_update_time = 0
            self.work.info.update_time = 0.0
        # Start timer
        self.work.timer = time.time()

        # Work on a private copy so the caller's array is never aliased
        l_scaled = np.copy(l_new)

        # Scale vector
        if self.work.settings.scaling:
            l_scaled = self.work.scaling.E.dot(l_scaled)

        # Check values before committing them to the problem data
        if not np.greater_equal(self.work.data.u, l_scaled).all():
            raise ValueError('Lower bound must be lower than or equal to upper bound!')

        # Update lower bound
        self.work.data.l = l_scaled

        # Reset solver info
        self.reset_info(self.work.info)

        # If type of any constraint changed, update rho_vec and KKT matrix
        self.update_rho_vec()

        # Update update_time
        self.work.info.update_time += time.time() - self.work.timer

    def update_upper_bound(self, u_new):
        """
        Update upper bound; no new factorization is requested here

        A copy of u_new is scaled (when scaling is enabled) and compared with
        the stored lower bound. Only a valid bound is written to the problem
        data, so a rejected update leaves the workspace unchanged.

        Raises
        ------
        ValueError if some entry of the new upper bound is below the stored
        lower bound.
        """
        if self.work.clear_update_time == 1:
            # Clear update_time
            self.work.clear_update_time = 0
            self.work.info.update_time = 0.0
        # Start timer
        self.work.timer = time.time()

        # Work on a private copy so the caller's array is never aliased
        u_scaled = np.copy(u_new)

        # Scale vector
        if self.work.settings.scaling:
            u_scaled = self.work.scaling.E.dot(u_scaled)

        # Check values before committing them to the problem data
        if not np.greater_equal(u_scaled, self.work.data.l).all():
            raise ValueError('Lower bound must be lower than or equal to upper bound!')

        # Update upper bound
        self.work.data.u = u_scaled

        # Reset solver info
        self.reset_info(self.work.info)

        # If type of any constraint changed, update rho_vec and KKT matrix
        self.update_rho_vec()

        # Update update_time
        self.work.info.update_time += time.time() - self.work.timer

    def update_P(self, P_new):
        """
        Replace the quadratic cost matrix and rebuild the linear system solver

        With scaling enabled the stored matrix is c * D * P_new * D,
        otherwise P_new itself is stored. The time spent is added to
        info.update_time.
        """
        work = self.work

        # First update after a solve: restart the update time accumulator
        if work.clear_update_time == 1:
            work.clear_update_time = 0
            work.info.update_time = 0.0

        work.timer = time.time()

        if not work.settings.scaling:
            work.data.P = P_new
        else:
            D = work.scaling.D
            work.data.P = work.scaling.c * D.dot(P_new.dot(D))

        # The KKT matrix changed
        work.linsys_solver = linsys_solver(work)

        work.info.update_time += time.time() - work.timer

    def update_A(self, A_new):
        """
        Replace the constraint matrix and rebuild the linear system solver

        With scaling enabled the stored matrix is E * A_new * D, otherwise
        A_new itself is stored. The time spent is added to info.update_time.
        """
        work = self.work

        # First update after a solve: restart the update time accumulator
        if work.clear_update_time == 1:
            work.clear_update_time = 0
            work.info.update_time = 0.0

        work.timer = time.time()

        if not work.settings.scaling:
            work.data.A = A_new
        else:
            scaled_cols = A_new.dot(work.scaling.D)
            work.data.A = work.scaling.E.dot(scaled_cols)

        # The KKT matrix changed
        work.linsys_solver = linsys_solver(work)

        work.info.update_time += time.time() - work.timer

    def update_P_A(self, P_new, A_new):
        """
        Update quadratic cost and constraint matrices

        With scaling enabled the stored matrices are c * D * P_new * D and
        E * A_new * D, otherwise the new matrices are stored as given. The
        linear system solver is rebuilt afterwards and the time spent is
        added to info.update_time.
        """
        if self.work.clear_update_time == 1:
            # Clear update_time
            self.work.clear_update_time = 0
            self.work.info.update_time = 0.0
        # Start timer
        self.work.timer = time.time()

        if self.work.settings.scaling:
            D = self.work.scaling.D
            # The cost scaling factor c multiplies P here exactly as in
            # update_P (and q in update_lin_cost); leaving it out would give
            # a scaled P inconsistent with the scaled q.
            self.work.data.P = self.work.scaling.c * D.dot(P_new.dot(D))
            self.work.data.A = self.work.scaling.E.dot(A_new.dot(D))
        else:
            self.work.data.P = P_new
            self.work.data.A = A_new
        self.work.linsys_solver = linsys_solver(self.work)

        # Update update_time
        self.work.info.update_time += time.time() - self.work.timer

    def warm_start(self, x, y):
        """
        Warm start primal and dual variables

        x and y are given for the original (unscaled) problem. Copies are
        stored as the new iterates; when scaling is enabled they are first
        mapped to the scaled problem with x_scaled = Dinv * x and
        y_scaled = c * Einv * y, the inverse of the unscaling applied in
        store_solution. z is set to A * x using the stored A.
        """
        # Update warm_start setting to true
        self.work.settings.warm_start = True

        # Private copies, so the caller's arrays are never aliased
        x_start = np.copy(x)
        y_start = np.copy(y)

        # Scale iterates (only when the problem data is scaled)
        if self.work.settings.scaling:
            x_start = self.work.scaling.Dinv.dot(x_start)
            y_start = self.work.scaling.c * self.work.scaling.Einv.dot(y_start)

        self.work.x = x_start
        self.work.y = y_start

        # Update z iterate as well
        self.work.z = self.work.data.A.dot(self.work.x)

    def warm_start_x(self, x):
        """
        Warm start primal variable

        x is given for the original (unscaled) problem. A copy is stored as
        the new primal iterate, mapped to the scaled problem with Dinv when
        scaling is enabled. z is set to A * x using the stored A and the
        dual iterate y is reset to zero.
        """
        # Update warm_start setting to true
        self.work.settings.warm_start = True

        # Private copy, so the caller's array is never aliased
        x_start = np.copy(x)

        # Scale iterate (only when the problem data is scaled)
        if self.work.settings.scaling:
            x_start = self.work.scaling.Dinv.dot(x_start)

        self.work.x = x_start

        # Update z iterate as well
        self.work.z = self.work.data.A.dot(self.work.x)

        # Cold start y
        self.work.y = np.zeros(self.work.data.m)

    def warm_start_y(self, y):
        """
        Warm start dual variable

        y is given for the original (unscaled) problem. A copy is stored as
        the new dual iterate; when scaling is enabled it is mapped to the
        scaled problem with y_scaled = c * Einv * y, the inverse of the
        unscaling applied in store_solution. x and z are reset to zero.
        """
        # Update warm_start setting to true
        self.work.settings.warm_start = True

        # Private copy, so the caller's array is never aliased
        y_start = np.copy(y)

        # Scale iterate (only when the problem data is scaled)
        if self.work.settings.scaling:
            y_start = self.work.scaling.c * self.work.scaling.Einv.dot(y_start)

        self.work.y = y_start

        # Cold start x and z
        self.work.x = np.zeros(self.work.data.n)
        self.work.z = np.zeros(self.work.data.m)

    #
    #   Update Problem Settings
    #
    def update_max_iter(self, max_iter_new):
        """
        Set the maximum number of iterations

        Raises ValueError when max_iter_new is zero or negative.
        """
        non_positive = max_iter_new <= 0
        if non_positive:
            raise ValueError('max_iter must be positive')

        settings = self.work.settings
        settings.max_iter = max_iter_new

    def update_eps_abs(self, eps_abs_new):
        """
        Set the absolute tolerance eps_abs

        Raises ValueError when eps_abs_new is zero or negative.
        """
        non_positive = eps_abs_new <= 0
        if non_positive:
            raise ValueError('eps_abs must be positive')

        settings = self.work.settings
        settings.eps_abs = eps_abs_new

    def update_eps_rel(self, eps_rel_new):
        """
        Set the relative tolerance eps_rel

        Raises ValueError when eps_rel_new is zero or negative.
        """
        non_positive = eps_rel_new <= 0
        if non_positive:
            raise ValueError('eps_rel must be positive')

        settings = self.work.settings
        settings.eps_rel = eps_rel_new

    def update_rho(self, rho_new):
        """
        Set the step-size parameter rho and rebuild the linear system solver

        The value is clamped to [RHO_MIN, RHO_MAX] before being stored.
        Entries of rho_vec whose constraint type is 0 receive rho, entries
        whose type is 1 receive RHO_EQ_OVER_RHO_INEQ * rho, and rho_inv_vec
        is recomputed as the elementwise reciprocal of rho_vec.

        Raises ValueError when rho_new is zero or negative.
        """
        if rho_new <= 0:
            raise ValueError('rho must be positive')

        work = self.work

        # Keep rho inside the admissible range
        clamped = np.minimum(np.maximum(rho_new, RHO_MIN), RHO_MAX)
        work.settings.rho = clamped

        # Positions of the constraints of type 0 and of type 1
        type0_pos = np.where(work.constr_type == 0)
        type1_pos = np.where(work.constr_type == 1)

        work.rho_vec[type0_pos] = work.settings.rho
        work.rho_vec[type1_pos] = RHO_EQ_OVER_RHO_INEQ * work.settings.rho
        work.rho_inv_vec = np.reciprocal(work.rho_vec)

        # The KKT matrix depends on rho: factorize again
        work.linsys_solver = linsys_solver(work)

    def update_alpha(self, alpha_new):
        """
        Update relaxation parameter alpha

        Raises
        ------
        ValueError if alpha_new does not lie strictly between 0 and 2.
        """
        # A chained comparison is used on purpose: the former test
        # `alpha_new >= 0 | alpha_new <= 2` applied the bitwise `|` before
        # the comparisons, which fails with TypeError for floats and lets
        # negative integers through.
        if not (0 < alpha_new < 2):
            raise ValueError('alpha must be between 0 and 2')

        self.work.settings.alpha = alpha_new

    def update_delta(self, delta_new):
        """
        Set the delta parameter used by polish

        Raises ValueError when delta_new is zero or negative.
        """
        non_positive = delta_new <= 0
        if non_positive:
            raise ValueError('delta must be positive')

        settings = self.work.settings
        settings.delta = delta_new

    def update_polish(self, polish_new):
        """
        Switch solution polish on or off

        Only the objects True and False are accepted; anything else raises
        ValueError. The stored polish time is reset to zero.
        """
        is_flag = polish_new is True or polish_new is False
        if not is_flag:
            raise ValueError('polish should be either True or False')

        work = self.work
        work.settings.polish = polish_new
        work.info.polish_time = 0.0

    def update_polish_refine_iter(self, polish_refine_iter_new):
        """
        Set the number of iterative refinement steps performed in polish

        Zero is allowed; a negative value raises ValueError.
        """
        negative = polish_refine_iter_new < 0
        if negative:
            raise ValueError('polish_refine_iter must be nonnegative')

        settings = self.work.settings
        settings.polish_refine_iter = polish_refine_iter_new

    def update_verbose(self, verbose_new):
        """
        Switch verbose output on or off

        Only the objects True and False are accepted; anything else raises
        ValueError.
        """
        is_flag = verbose_new is True or verbose_new is False
        if not is_flag:
            raise ValueError('verbose should be either True or False')

        settings = self.work.settings
        settings.verbose = verbose_new

    def update_scaled_termination(self, scaled_termination_new):
        """
        Switch the scaled_termination setting on or off

        Only the objects True and False are accepted; anything else raises
        ValueError.
        """
        is_flag = scaled_termination_new is True or scaled_termination_new is False
        if not is_flag:
            raise ValueError('scaled_termination should be either True or False')

        settings = self.work.settings
        settings.scaled_termination = scaled_termination_new

    def update_check_termination(self, check_termination_new):
        """
        Update check_termination parameter

        A value of 0 disables the termination test inside the ADMM loop;
        solve() then tests termination once, after the last iteration.

        Raises
        ------
        ValueError if check_termination_new is negative.
        """
        # 0 must be accepted: solve() has a dedicated branch for a disabled
        # check_termination setting.
        if check_termination_new < 0:
            raise ValueError('check_termination should be nonnegative')

        self.work.settings.check_termination = check_termination_new

    def update_warm_start(self, warm_start_new):
        """
        Switch the warm_start setting on or off

        Only the objects True and False are accepted; anything else raises
        ValueError.
        """
        is_flag = warm_start_new is True or warm_start_new is False
        if not is_flag:
            raise ValueError('warm_start should be either True or False')

        settings = self.work.settings
        settings.warm_start = warm_start_new

    def constant(self, constant_name):
        """
        Return solver constant

        Recognized names are OSQP_INFTY, OSQP_NAN and the status codes
        OSQP_SOLVED, OSQP_SOLVED_INACCURATE, OSQP_UNSOLVED,
        OSQP_PRIMAL_INFEASIBLE, OSQP_PRIMAL_INFEASIBLE_INACCURATE,
        OSQP_DUAL_INFEASIBLE, OSQP_DUAL_INFEASIBLE_INACCURATE,
        OSQP_MAX_ITER_REACHED and OSQP_NON_CVX.

        Raises
        ------
        ValueError if constant_name is not one of the names above.
        """
        # Every status code the solver can report is available here,
        # including the inaccurate and non-convex ones.
        known = {
            'OSQP_INFTY': OSQP_INFTY,
            'OSQP_NAN': OSQP_NAN,
            'OSQP_SOLVED': OSQP_SOLVED,
            'OSQP_SOLVED_INACCURATE': OSQP_SOLVED_INACCURATE,
            'OSQP_UNSOLVED': OSQP_UNSOLVED,
            'OSQP_PRIMAL_INFEASIBLE': OSQP_PRIMAL_INFEASIBLE,
            'OSQP_PRIMAL_INFEASIBLE_INACCURATE': OSQP_PRIMAL_INFEASIBLE_INACCURATE,
            'OSQP_DUAL_INFEASIBLE': OSQP_DUAL_INFEASIBLE,
            'OSQP_DUAL_INFEASIBLE_INACCURATE': OSQP_DUAL_INFEASIBLE_INACCURATE,
            'OSQP_MAX_ITER_REACHED': OSQP_MAX_ITER_REACHED,
            'OSQP_NON_CVX': OSQP_NON_CVX,
        }

        try:
            return known[constant_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which are unknown as well
            raise ValueError('Constant not recognized!')

    def iter_refin(self, KKT_factor, z, b):
        """
        Refine in place the solution z of the reduced KKT system K * z = b

        Performs settings.polish_refine_iter correction steps. Each step
        evaluates the residual b - K * z with K = [P Ared'; Ared 0], solves
        for the correction with the supplied factorization and adds it to z:
            1. (K + dK) * dz = b - K*z
            2. z <- z + dz
        The array passed as z is modified and returned.
        """
        data = self.work.data
        pol = self.work.pol

        for _ in range(self.work.settings.polish_refine_iter):
            n = data.n
            primal, dual = z[:n], z[n:]

            # K * z, evaluated block by block
            upper = data.P.dot(primal) + pol.Ared.T.dot(dual)
            lower = pol.Ared.dot(primal)
            residual = b - np.hstack([upper, lower])

            # Correction step, accumulated into the caller's array
            z += KKT_factor.solve(residual)
        return z

    def polish(self):
        """
        Solution polish:
        Solve equality constrained QP with assumed active constraints.

        The active constraints are guessed from the current iterates
        (z, y), a regularized reduced KKT system is solved for them and the
        result is stored in the polish structure (pol.x, pol.z, pol.y).
        If the polished point improves the residuals (see pol_success
        below) it replaces the ADMM iterates and the values in info;
        otherwise info.status_polish is set to -1 and a line search between
        the ADMM and the polished point is computed.

        Returns
        -------
        A linesearch object. Its t, X, Z, Y fields are assigned only when
        polish was not successful.
        """
        # Start timer
        self.work.timer = time.time()

        # Guess which linear constraints are lower-active, upper-active, free
        # (lower-active: z - l < -y, upper-active: u - z < y)
        self.work.pol.ind_low = np.where(self.work.z - self.work.data.l < -self.work.y)[0]
        self.work.pol.ind_upp = np.where(self.work.data.u - self.work.z < self.work.y)[0]
        self.work.pol.n_low = len(self.work.pol.ind_low)
        self.work.pol.n_upp = len(self.work.pol.ind_upp)

        # Form Ared from the assumed active constraints
        # (rows of the lower-active constraints first, then the upper-active)
        self.work.pol.Ared = spspa.vstack(
            [
                self.work.data.A[self.work.pol.ind_low],
                self.work.data.A[self.work.pol.ind_upp],
            ]
        )

        # NOTE(review): the early exit below is disabled, so the reduced
        # system is formed even when no constraint is guessed active.
        # # Terminate if there are no active constraints
        # if self.work.pol.Ared.shape[0] == 0:
        #     return

        # Form and factorize reduced KKT
        #   [ P + delta*I    Ared'    ]
        #   [ Ared          -delta*I  ]
        # where delta is the regularization parameter from the settings
        KKTred = spspa.vstack(
            [
                spspa.hstack(
                    [
                        self.work.data.P + self.work.settings.delta * spspa.eye(self.work.data.n),
                        self.work.pol.Ared.T,
                    ]
                ),
                spspa.hstack(
                    [
                        self.work.pol.Ared,
                        -self.work.settings.delta * spspa.eye(self.work.pol.Ared.shape[0]),
                    ]
                ),
            ]
        )
        KKTred_factor = spla.splu(KKTred.tocsc())

        # Form reduced RHS: [-q; l at the lower-active rows; u at the upper-active rows]
        rhs_red = np.hstack(
            [
                -self.work.data.q,
                self.work.data.l[self.work.pol.ind_low],
                self.work.data.u[self.work.pol.ind_upp],
            ]
        )

        # Solve reduced KKT system
        pol_sol = KKTred_factor.solve(rhs_red)

        # Perform iterative refinement to compensate for the reg. error
        if self.work.settings.polish_refine_iter > 0:
            pol_sol = self.iter_refin(KKTred_factor, pol_sol, rhs_red)

        # Store the polished solution (x,z,y)
        # The first n entries are x; the remaining ones are the multipliers
        # of the active constraints, scattered back into a full-length y
        # that is zero at the constraints assumed free.
        self.work.pol.x = pol_sol[: self.work.data.n]
        self.work.pol.z = self.work.data.A.dot(self.work.pol.x)
        self.work.pol.y = np.zeros(self.work.data.m)
        y_red = pol_sol[self.work.data.n :]
        self.work.pol.y[self.work.pol.ind_low] = y_red[: self.work.pol.n_low]
        self.work.pol.y[self.work.pol.ind_upp] = y_red[self.work.pol.n_low :]

        # Ensure (z,y) satisfies normal cone constraint
        self.work.pol.z, self.work.pol.y = self.project_normalcone(self.work.pol.z, self.work.pol.y)

        # Compute primal and dual residuals of the polished solution
        # (also stores the polish objective value and info.polish_time)
        self.update_info(0, 1)

        # Check if polish was successful: both residuals improved, or one
        # improved while the other ADMM residual was already below 1e-10
        pol_success = (
            (self.work.pol.pri_res < self.work.info.pri_res)
            and (self.work.pol.dua_res < self.work.info.dua_res)
            or (self.work.pol.pri_res < self.work.info.pri_res)
            and (self.work.info.dua_res < 1e-10)
            or (self.work.pol.dua_res < self.work.info.dua_res)
            and (self.work.info.pri_res < 1e-10)
        )

        # Container returned to the caller in both outcomes
        ls = linesearch()

        if pol_success:
            # Update solver information
            self.work.info.obj_val = self.work.pol.obj_val
            self.work.info.pri_res = self.work.pol.pri_res
            self.work.info.dua_res = self.work.pol.dua_res
            self.work.info.status_polish = 1

            # Update ADMM iterations
            self.work.x = self.work.pol.x
            self.work.z = self.work.pol.z
            self.work.y = self.work.pol.y

            # Print summary
            if self.work.settings.verbose:
                self.print_polish()

        else:
            self.work.info.status_polish = -1

            # Line search on the line connecting the ADMM and the polished sol.
            # (1000 step sizes between 0 and 0.002)
            ls.t = np.linspace(0.0, 0.002, 1000)
            ls.X, ls.Z, ls.Y = self.line_search(
                self.work.x,
                self.work.z,
                self.work.y,
                self.work.pol.x,
                self.work.pol.z,
                self.work.pol.y,
                ls.t,
            )

        return ls

    def line_search(self, x1, z1, y1, x2, z2, y2, t):
        """
        Perform line search on the line between (x1,z1,y1) and (x2,z2,y2).

        For every step size in t the point (x1,z1,y1) + t * (direction) is
        formed, (z, y) is passed through project_normalcone and, when
        scaling is enabled, the point is mapped back to the original
        problem: x = D * x, z = Einv * z, y = cinv * E * y.

        Returns
        -------
        (X, Z, Y): arrays with one row per entry of t.
        """
        N = len(t)
        X = np.zeros((N, self.work.data.n))
        Z = np.zeros((N, self.work.data.m))
        Y = np.zeros((N, self.work.data.m))

        # Search direction
        dx = x2 - x1
        dz = z2 - z1
        dy = y2 - y1

        for i, step in enumerate(t):
            X[i, :] = x1 + step * dx
            Z[i, :] = z1 + step * dz
            Y[i, :] = y1 + step * dy
            Z[i, :], Y[i, :] = self.project_normalcone(Z[i, :], Y[i, :])

            # Unscale optimization variables (x,z,y)
            if self.work.settings.scaling:
                X[i, :] = self.work.scaling.D.dot(X[i, :])
                Z[i, :] = self.work.scaling.Einv.dot(Z[i, :])
                # Same dual unscaling as store_solution, including cinv
                Y[i, :] = self.work.scaling.cinv * self.work.scaling.E.dot(Y[i, :])

        return (X, Z, Y)
